In [1]:
#imports
import os
import re
import glob
import math
import cv2
import gc
import csv
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from transunet import TransUNet
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from tensorflow import keras
from tensorflow.keras import layers
# List available GPUs
gpus = tf.config.list_physical_devices('GPU')
print("GPUs: ", gpus)
if gpus:
print("TensorFlow is using the GPU.")
else:
print("TensorFlow is not using the GPU.")
2024-12-04 19:46:47.325076: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
GPUs: [] TensorFlow is not using the GPU.
/home/des/anaconda3/lib/python3.10/site-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: TensorFlow Addons (TFA) has ended development and introduction of new features. TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024. Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). For more information see: https://github.com/tensorflow/addons/issues/2807 warnings.warn(
In [2]:
def rle_to_binary(rle, shape):
"""
Converts a RLE (run length encoding) to a binary mask.
"""
# Initialize a flat mask with zeros
mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
if rle == '' or rle == '0': # Handle empty RLE
return mask.reshape(shape, order='C')
# Decode RLE into mask
rle_numbers = list(map(int, rle.split()))
for i in range(0, len(rle_numbers), 2):
start = rle_numbers[i] - 1 # Convert to zero-indexed
length = rle_numbers[i + 1]
mask[start:start + length] = 1
# Reshape flat mask into 2D
return mask.reshape(shape, order='C')
def custom_generator(gdf, dir, batch_size, target_size=(224, 224), test_mode=False):
"""
Custom data generator that dynamically aligns images and masks using RLE decoding.
Parameters:
gdf (GroupBy): Grouped dataframe containing image IDs and RLEs.
dir (str): Root directory of the dataset.
batch_size (int): Number of samples per batch.
target_size (tuple): Target size for resizing (default=(224, 224)).
test_mode (bool): If True, yields one image and mask at a time.
"""
ids = list(gdf.groups.keys())
dir2 = 'train'
while True:
sample_ids = np.random.choice(ids, size=batch_size, replace=False)
images, masks = [], []
for id_num in sample_ids:
# Get the dataframe rows for the current image
img_rows = gdf.get_group(id_num)
rle_list = img_rows['segmentation'].tolist()
# Construct the file path for the image
sections = id_num.split('_')
case = sections[0]
day = sections[0] + '_' + sections[1]
slice_id = sections[2] + '_' + sections[3]
pattern = os.path.join(dir, dir2, case, day, "scans", f"{slice_id}*.png")
filelist = glob.glob(pattern)
if filelist:
file = filelist[0]
image = cv2.imread(file, cv2.IMREAD_COLOR)
if image is None:
print(f"Image not found: {file}")
continue # Skip if the image is missing
# Original shape of the image
original_shape = image.shape[:2]
# Resize the image
resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
# Decode and resize the masks
mask = np.zeros((target_size[0], target_size[1], len(rle_list)), dtype=np.uint8)
for i, rle in enumerate(rle_list):
if rle != '0': # Check if the RLE is valid
decoded_mask = rle_to_binary(rle, original_shape)
resized_mask = cv2.resize(decoded_mask, target_size, interpolation=cv2.INTER_NEAREST)
mask[:, :, i] = resized_mask
if test_mode:
# Return individual samples in test mode
yield resized_image[np.newaxis], mask[np.newaxis], pattern
else:
images.append(resized_image)
masks.append(mask)
if not test_mode:
x = np.array(images)
y = np.array(masks)
yield x, y
def dice_coef(y_true, y_pred, smooth=1e-6):
# Ensure consistent data types
y_true = tf.cast(y_true, tf.float32)
y_pred = tf.cast(y_pred, tf.float32)
y_true_f = tf.keras.backend.flatten(y_true)
y_pred_f = tf.keras.backend.flatten(y_pred)
intersection = tf.reduce_sum(y_true_f * y_pred_f)
return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
def dice_loss(y_true, y_pred):
y_true = tf.cast(y_true, tf.float32)
y_pred = tf.cast(y_pred, tf.float32)
return 1 - dice_coef(y_true, y_pred)
In [3]:
dir = './Dataset'
In [4]:
df = pd.read_csv(os.path.join('.', dir, 'train.csv'))
df['segmentation'] = df['segmentation'].fillna('0')
train_ids, temp_ids = train_test_split(df.id.unique(), test_size=0.1, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)
train_grouped_df = df[df.id.isin(train_ids)].groupby('id')
val_grouped_df = df[df.id.isin(val_ids)].groupby('id')
test_grouped_df = df[df.id.isin(test_ids)].groupby('id')
batch_size = 24
target_size = 224
epochs = 8
# steps per epoch is typically train length / batch size to use all training examples
train_steps_per_epoch = math.ceil(len(train_ids) / batch_size)
val_steps_per_epoch = math.ceil(len(val_ids) / batch_size)
test_steps_per_epoch = math.ceil(len(test_ids) / batch_size)
# create the training and validation datagens
train_generator = custom_generator(train_grouped_df, dir, batch_size, (target_size, target_size))
val_generator = custom_generator(val_grouped_df, dir, batch_size, (target_size, target_size))
test_generator = custom_generator(test_grouped_df, dir, batch_size, (target_size, target_size), test_mode=True)
In [5]:
loading = True
if loading:
model = TransUNet(image_size=224, pretrain=False)
model.load_weights('./impmodels/model_weights.h5')
model.compile(optimizer='adam', loss=dice_loss, metrics=['accuracy'])
else:
# create the optimizer and learning rate scheduler
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
initial_learning_rate = 1e-3,
# decay_steps=train_steps_per_epoch * epochs,
decay_steps=epochs+2,
alpha=1e-2
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
# create the U-net neural network
model = TransUNet(image_size=target_size, freeze_enc_cnn=False, pretrain=True)
model.compile(optimizer=optimizer, loss=dice_loss, metrics=['accuracy'])
# set up model checkpoints and early stopping
checkpoints_path = os.path.join('Checkpoints', 'model_weights.h5')
model_checkpoint = ModelCheckpoint(filepath=checkpoints_path, save_best_only=True, monitor='val_loss')
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=8)
# log the training to a .csv for reference
csv_logger = CSVLogger('training_log.csv', append=True)
history = model.fit(train_generator, validation_data=val_generator, steps_per_epoch=train_steps_per_epoch, validation_steps=val_steps_per_epoch, epochs=epochs, callbacks=[model_checkpoint, early_stopping, csv_logger])
2024-12-04 19:46:49.954136: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2024-12-04 19:46:49.956644: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
In [6]:
preds = []
ground_truths = []
num_samples = 24
# Generate predictions and ground truths
for i in range(num_samples):
# Fetch a batch from the test generator
batch = next(test_generator)
image, mask, path = batch # Assuming generator returns (images, masks)
preds.append(model.predict(image)) # Predict using the model
ground_truths.append(mask)
best_threshold = 0.99
# Apply the best threshold to all predictions
final_preds = [(pred >= best_threshold).astype(int) for pred in preds]
# Compute Dice loss for each prediction
for i in range(len(final_preds)):
loss = dice_loss(ground_truths[i], final_preds[i]) # Assuming `dice_loss` is defined
print(f"Image {i + 1}: Dice Loss = {loss:.4f}")
1/1 [==============================] - 3s 3s/step 1/1 [==============================] - 0s 407ms/step 1/1 [==============================] - 0s 454ms/step 1/1 [==============================] - 0s 431ms/step 1/1 [==============================] - 0s 387ms/step 1/1 [==============================] - 0s 393ms/step 1/1 [==============================] - 0s 337ms/step 1/1 [==============================] - 0s 376ms/step 1/1 [==============================] - 0s 462ms/step 1/1 [==============================] - 0s 383ms/step 1/1 [==============================] - 0s 413ms/step 1/1 [==============================] - 0s 377ms/step 1/1 [==============================] - 0s 373ms/step 1/1 [==============================] - 1s 737ms/step 1/1 [==============================] - 0s 316ms/step 1/1 [==============================] - 0s 439ms/step 1/1 [==============================] - 0s 408ms/step 1/1 [==============================] - 0s 436ms/step 1/1 [==============================] - 1s 512ms/step 1/1 [==============================] - 0s 365ms/step 1/1 [==============================] - 0s 345ms/step 1/1 [==============================] - 0s 373ms/step 1/1 [==============================] - 0s 480ms/step 1/1 [==============================] - 0s 366ms/step Image 1: Dice Loss = 0.0626 Image 2: Dice Loss = 0.0000 Image 3: Dice Loss = 1.0000 Image 4: Dice Loss = 0.1234 Image 5: Dice Loss = 0.0000 Image 6: Dice Loss = 0.0000 Image 7: Dice Loss = 0.0000 Image 8: Dice Loss = 0.1158 Image 9: Dice Loss = 0.0000 Image 10: Dice Loss = 0.1349 Image 11: Dice Loss = 0.0000 Image 12: Dice Loss = 0.0708 Image 13: Dice Loss = 0.0000 Image 14: Dice Loss = 0.1931 Image 15: Dice Loss = 0.0000 Image 16: Dice Loss = 0.0601 Image 17: Dice Loss = 0.0468 Image 18: Dice Loss = 0.0000 Image 19: Dice Loss = 0.0000 Image 20: Dice Loss = 0.0000 Image 21: Dice Loss = 0.0000 Image 22: Dice Loss = 0.0000 Image 23: Dice Loss = 0.0000 Image 24: Dice Loss = 0.0000
In [7]:
def visualize_predictions(generator, model, num_samples=8, target_size=(224, 224)):
"""
Visualize predictions vs. ground truths overlaid on original images.
Parameters:
generator (generator): Data generator for test data.
model (Model): Trained segmentation model.
num_samples (int): Number of samples to visualize.
target_size (tuple): Target size for resizing (default=(224, 224)).
"""
fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
for i in range(num_samples):
# Fetch one image and mask from the generator
image_batch, mask_batch, path = next(generator)
image = image_batch[0] # Single image
ground_truth = mask_batch[0] # Corresponding ground truth mask
# Ensure image is RGB
if len(image.shape) == 2:
image = np.stack([image] * 3, axis=-1) # Convert grayscale to RGB
# Ensure ground truth is a single-channel binary mask
if ground_truth.ndim == 3 and ground_truth.shape[-1] == 3:
ground_truth = ground_truth[:, :, 0] # Extract the first channel
# Generate prediction
raw_prediction = model.predict(image[np.newaxis])[0] # Add batch dimension for prediction
# Ensure prediction is single-channel
if raw_prediction.ndim == 3 and raw_prediction.shape[-1] == 3:
prediction = raw_prediction[:, :, 0] # Extract the first channel
else:
prediction = raw_prediction
prediction = (prediction >= 0.99).astype(np.uint8) # Threshold prediction
# Create overlays
gt_overlay = image.copy()
pred_overlay = image.copy()
# Overlay ground truth in red
gt_overlay[ground_truth == 1] = [255, 0, 0]
# Overlay prediction in green
pred_overlay[prediction == 1] = [0, 255, 0]
# Plot original image, ground truth overlay, and prediction overlay
axes[i, 0].imshow(image)
axes[i, 0].set_title(f"Image {i + 1}")
axes[i, 0].axis('off')
axes[i, 1].imshow(gt_overlay)
axes[i, 1].set_title(f"Ground Truth Overlay {i + 1}")
axes[i, 1].axis('off')
axes[i, 2].imshow(pred_overlay)
axes[i, 2].set_title(f"Prediction Overlay {i + 1}")
axes[i, 2].axis('off')
plt.tight_layout()
plt.show()
# Call the function with test generator and trained model
visualize_predictions(test_generator, model, num_samples=24)
1/1 [==============================] - 0s 466ms/step 1/1 [==============================] - 0s 401ms/step 1/1 [==============================] - 0s 391ms/step 1/1 [==============================] - 0s 380ms/step 1/1 [==============================] - 0s 389ms/step 1/1 [==============================] - 0s 407ms/step 1/1 [==============================] - 0s 401ms/step 1/1 [==============================] - 0s 334ms/step 1/1 [==============================] - 0s 434ms/step 1/1 [==============================] - 0s 427ms/step 1/1 [==============================] - 0s 368ms/step 1/1 [==============================] - 0s 379ms/step 1/1 [==============================] - 1s 550ms/step 1/1 [==============================] - 0s 342ms/step 1/1 [==============================] - 0s 402ms/step 1/1 [==============================] - 0s 454ms/step 1/1 [==============================] - 0s 396ms/step 1/1 [==============================] - 0s 433ms/step 1/1 [==============================] - 0s 421ms/step 1/1 [==============================] - 0s 333ms/step 1/1 [==============================] - 0s 417ms/step 1/1 [==============================] - 0s 398ms/step 1/1 [==============================] - 0s 344ms/step 1/1 [==============================] - 0s 390ms/step
In [8]:
def binary_to_rle(binary_mask):
"""
Converts a binary mask to RLE (Run-Length Encoding).
"""
# Flatten mask in column-major order
flat_mask = binary_mask.T.flatten()
rle = []
start = -1
for i, val in enumerate(flat_mask):
if val == 1 and start == -1:
start = i
elif val == 0 and start != -1:
rle.extend([start + 1, i - start])
start = -1
if start != -1:
rle.extend([start + 1, len(flat_mask) - start])
return ' '.join(map(str, rle))
In [9]:
def save_predictions_to_csv(test_generator, model, output_csv_path):
"""
Generates predictions using the trained model and writes them to a CSV file in RLE format.
Parameters:
test_generator: The data generator for the test set.
model: The trained segmentation model.
output_csv_path: Path to save the CSV file.
"""
with open(output_csv_path, mode='w', newline='') as csvfile:
csv_writer = csv.writer(csvfile)
csv_writer.writerow(['id', 'segmentation']) # Header row
i = 0
for image, masks, ids in test_generator:
i += 1
if i > 81:
break;
predictions = model.predict(image)
predictions = (predictions > 0.99).astype(int)
for pred_mask, mask_id in zip(predictions, ids):
rle = binary_to_rle(pred_mask.squeeze())
csv_writer.writerow([mask_id, rle])
# Clear memory after each batch just in case
del predictions, image, masks, ids
gc.collect()
In [10]:
predictions_output_path = 'model_output.csv'
save_predictions_to_csv(test_generator, model, predictions_output_path)
1/1 [==============================] - 0s 489ms/step 1/1 [==============================] - 0s 448ms/step 1/1 [==============================] - 0s 394ms/step 1/1 [==============================] - 0s 447ms/step 1/1 [==============================] - 0s 379ms/step 1/1 [==============================] - 0s 465ms/step 1/1 [==============================] - 0s 384ms/step 1/1 [==============================] - 0s 342ms/step 1/1 [==============================] - 0s 457ms/step 1/1 [==============================] - 0s 360ms/step 1/1 [==============================] - 0s 389ms/step 1/1 [==============================] - 0s 457ms/step 1/1 [==============================] - 0s 479ms/step 1/1 [==============================] - 0s 495ms/step 1/1 [==============================] - 1s 519ms/step 1/1 [==============================] - 0s 430ms/step 1/1 [==============================] - 0s 459ms/step 1/1 [==============================] - 0s 409ms/step 1/1 [==============================] - 0s 499ms/step 1/1 [==============================] - 0s 390ms/step 1/1 [==============================] - 0s 380ms/step 1/1 [==============================] - 0s 468ms/step 1/1 [==============================] - 0s 465ms/step 1/1 [==============================] - 0s 400ms/step 1/1 [==============================] - 0s 372ms/step 1/1 [==============================] - 0s 486ms/step 1/1 [==============================] - 0s 342ms/step 1/1 [==============================] - 0s 446ms/step 1/1 [==============================] - 0s 461ms/step 1/1 [==============================] - 0s 394ms/step 1/1 [==============================] - 1s 517ms/step 1/1 [==============================] - 1s 574ms/step 1/1 [==============================] - 0s 417ms/step 1/1 [==============================] - 0s 413ms/step 1/1 [==============================] - 0s 422ms/step 1/1 [==============================] - 0s 452ms/step 1/1 [==============================] - 0s 469ms/step 1/1 [==============================] - 0s 485ms/step 1/1 [==============================] - 0s 476ms/step 1/1 [==============================] - 0s 429ms/step 1/1 [==============================] - 0s 394ms/step 1/1 [==============================] - 0s 395ms/step 1/1 [==============================] - 0s 429ms/step 1/1 [==============================] - 0s 360ms/step 1/1 [==============================] - 0s 388ms/step 1/1 [==============================] - 0s 479ms/step 1/1 [==============================] - 0s 441ms/step 1/1 [==============================] - 1s 502ms/step 1/1 [==============================] - 0s 425ms/step 1/1 [==============================] - 1s 711ms/step 1/1 [==============================] - 1s 530ms/step 1/1 [==============================] - 0s 439ms/step 1/1 [==============================] - 0s 451ms/step 1/1 [==============================] - 0s 436ms/step 1/1 [==============================] - 1s 551ms/step 1/1 [==============================] - 0s 394ms/step 1/1 [==============================] - 1s 528ms/step 1/1 [==============================] - 0s 386ms/step 1/1 [==============================] - 1s 598ms/step 1/1 [==============================] - 1s 588ms/step 1/1 [==============================] - 1s 554ms/step 1/1 [==============================] - 0s 481ms/step 1/1 [==============================] - 0s 354ms/step 1/1 [==============================] - 0s 497ms/step 1/1 [==============================] - 1s 664ms/step 1/1 [==============================] - 1s 787ms/step 1/1 [==============================] - 1s 884ms/step 1/1 [==============================] - 1s 1s/step 1/1 [==============================] - 1s 828ms/step 1/1 [==============================] - 1s 663ms/step 1/1 [==============================] - 1s 895ms/step 1/1 [==============================] - 1s 882ms/step 1/1 [==============================] - 1s 514ms/step 1/1 [==============================] - 0s 499ms/step 1/1 [==============================] - 0s 362ms/step 1/1 [==============================] - 0s 467ms/step 1/1 [==============================] - 0s 477ms/step 1/1 [==============================] - 0s 397ms/step 1/1 [==============================] - 0s 415ms/step 1/1 [==============================] - 0s 484ms/step 1/1 [==============================] - 0s 474ms/step